import numpy as np
import pandas as pd
import collections
import torch
import random
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import balanced_accuracy_score, f1_score
from tqdm import tqdm


# Raw comment table (GB18030-encoded) and the list of feature column names.
df_data = pd.read_csv('./feature.GB18030.v10.csv', encoding='GB18030')
df_features = pd.read_csv('./feature_names.csv')
feature_names = df_features['feature'].tolist()

# Impute each column's missing values with the first-listed mode of that column.
# NOTE(review): DataFrame.mode() ignores NaN, so after this imputation only
# columns with no observed values at all can still hold NaN. The dropna below is
# therefore close to a no-op; in particular a missing plain_text is replaced by
# the most frequent text rather than dropped. Confirm this ordering is intended.
modes_dict = df_data.mode().iloc[0].to_dict()
df_data = df_data.fillna(modes_dict)
df_data = df_data.dropna(subset=[*feature_names, 'plain_text'])

# Keep only rows whose text is at least `min_txt_len` characters long.
min_txt_len = 1
filter2 = df_data['plain_text'].map(lambda text: len(text) >= min_txt_len)
df_data = df_data.loc[filter2]

# Engagement thresholds on (comment_like_num, child_comment_num).
split_x_high = 8.75
split_y_high = 6.0

split_x_low = 2.0
split_y_low = 2.0

# Assign one of three classes per row. The name `binary_labels` is kept because
# the code below refers to it, even though there are three classes:
#   0 -> likes <= split_x_low  and child comments <= split_y_low
#   2 -> likes >= split_x_high and child comments >= split_y_high
#   1 -> everything else (a comparison against NaN is False, so NaN rows land here)
# np.select evaluates the conditions in order like the former if/elif chain, but
# vectorised instead of an iterrows loop with two chained lookups per row.
likes = df_data['comment_like_num'].to_numpy()
child_comments = df_data['child_comment_num'].to_numpy()
binary_labels = np.select(
    [
        (likes <= split_x_low) & (child_comments <= split_y_low),
        (likes >= split_x_high) & (child_comments >= split_y_high),
    ],
    [0, 2],
    default=1,
)
print(collections.Counter(binary_labels))

# Leftover training hyper-parameters. In this part of the script only `epochs`
# is read, and it is used as the number of random-permutation trials per fold.
# batch_size, LR and output_dim are not referenced here.
batch_size = 128
LR = 1e-6
epochs = 10000
output_dim = 3

# Chance-level baseline: for each stratified fold, score the fold's true test
# labels against `epochs` random permutations of themselves. Print the mean
# balanced accuracy, micro-F1 and macro-F1. A permutation keeps every class's
# predicted count equal to its true count, so this estimates what a random
# guesser that matches the label distribution would score.
kf = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)
# Seeded generator so the Monte-Carlo estimates are reproducible, like the
# seeded fold split above. The global `random.shuffle` RNG was never seeded.
rng = np.random.default_rng(42)
splits = kf.split(df_data['plain_text'], y=binary_labels)
for fold_idx, (_, test_ids) in enumerate(splits, start=1):
    # Ground-truth labels for this fold; loop-invariant, so extract them once.
    true_labels = binary_labels[test_ids]
    bacc = 0
    macro_f1 = 0
    micro_f1 = 0
    for _ in range(epochs):
        shuffled = rng.permutation(true_labels)
        bacc += balanced_accuracy_score(true_labels, shuffled)
        macro_f1 += f1_score(true_labels, shuffled, average='macro')
        micro_f1 += f1_score(true_labels, shuffled, average='micro')
    # Output columns: fold, balanced accuracy, micro-F1, macro-F1.
    print(fold_idx, bacc/epochs, micro_f1/epochs, macro_f1/epochs)
